from distributed_pcg.utils import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp



def create_data(n, d_list, l, seed, a):
    """Build synthetic square matrices whose Gram spectrum decays geometrically.

    For each dimension ``d`` in ``d_list`` an ``n x d`` Gaussian matrix is
    drawn, the SVD factors ``U`` and ``Vt`` of its ``d x d`` Gram matrix are
    taken, and ``A = U @ diag(a**0, ..., a**(d-1))**0.5 @ Vt`` is formed.
    The effective dimension ``trace(H @ inv(H + l*I))`` with
    ``H = A.T @ A / n`` is computed alongside.

    Args:
        n: Number of rows of the Gaussian draw; also the normaliser in ``H``.
        d_list: Sequence of dimensions, one matrix is produced per entry.
        l: Regularisation added to ``H`` in the effective dimension.
        seed: Seed passed to ``np.random.seed`` (resets the global NumPy RNG).
        a: Base of the geometric sequence placed on the diagonal.

    Returns:
        ``(A_list, dl_list)``: the ``d x d`` matrices and their effective
        dimensions, in the order of ``d_list``.
    """
    np.random.seed(seed)
    matrices = []
    effective_dims = []
    for dim in d_list:
        gaussian = np.random.normal(size=(n, dim))
        # Only the singular vectors of the Gram matrix are kept; the
        # singular values are replaced by the prescribed sequence below.
        left, _, right_t = np.linalg.svd(gaussian.T @ gaussian)
        spectrum = np.diag([a ** power for power in range(dim)])
        # Element-wise square root: off-diagonal zeros stay zero.
        root_spectrum = spectrum ** 0.5
        shaped = left @ root_spectrum @ right_t

        normalised_gram = (shaped.T @ shaped) / n
        regularised = normalised_gram + l * np.eye(dim)
        effective_dims.append(
            np.trace(normalised_gram @ np.linalg.inv(regularised))
        )
        matrices.append(shaped)
    return matrices, effective_dims
    

def compute_m(A_list, d_list, dl_list, n, l):
    """Search, by doubling, for a Gaussian sketch size for each matrix.

    For every ``A`` the size ``m`` starts at 2. While ``m`` is smaller than
    the matching entry of ``d_list`` a fresh ``m x d`` Gaussian sketch ``S``
    (standard deviation ``1/sqrt(m)``) is drawn from the global NumPy RNG and
    ``trace(inv(S @ H @ S.T - z*I)) / m`` is evaluated with
    ``H = A.T @ A / n`` and ``z = -5*l/12``. The search stops as soon as that
    value exceeds ``1/l``; otherwise ``m`` is doubled.

    Args:
        A_list: Matrices to sketch.
        d_list: Dimension of each matrix; ``d_list[i]`` pairs with
            ``A_list[i]``.
        dl_list: Effective dimensions, passed through (scaled) in the result.
        n: Normaliser used when forming ``H``.
        l: Regularisation level.

    Returns:
        ``(dl, 1.5*dl, 4*dl, m_list)`` where ``dl`` is ``dl_list`` as an
        array and ``m_list`` is the list of selected sketch sizes.
    """
    shift = -5 * l / 12
    chosen_sizes = []
    for idx, matrix in enumerate(A_list):
        dim = d_list[idx]
        gram = matrix.T @ matrix / n
        sketch_dim = 2
        while sketch_dim < dim:
            sketch = np.random.normal(
                loc=0, scale=1 / (sketch_dim ** 0.5), size=(sketch_dim, dim)
            )
            shifted = sketch @ gram @ sketch.T - shift * np.eye(sketch_dim)
            criterion = np.trace(np.linalg.inv(shifted)) / sketch_dim
            if criterion > 1 / l:
                break
            sketch_dim = 2 * sketch_dim
        chosen_sizes.append(sketch_dim)

    base = np.array(dl_list)
    return base, 1.5 * base, 4 * base, chosen_sizes
    


def get_x_axis(data_set):
    """Return an evenly spaced x-grid from 0 to the largest final x-value.

    Each series in ``data_set`` is indexed as ``series[-1][0]`` to read the
    x-value of its last point. The series with the largest such value (the
    first one on ties) determines the grid: its final x is the right end
    point and its length sets the resolution, which is 2 points when the
    series has exactly two entries and at least 100 points otherwise.

    With no series, or none whose final x exceeds -1, the grid runs from 0
    to -1 with 100 points.
    """
    furthest_x, furthest_len = -1, 0
    for series in data_set:
        final_x = series[-1][0]
        if final_x > furthest_x:
            furthest_x, furthest_len = final_x, len(series)

    num_points = 2 if furthest_len == 2 else max(furthest_len, 100)
    return np.linspace(0, furthest_x, num_points)

def interpolate(data):
    """Summarise repeated runs by their median and a 20%-80% quantile band.

    Despite the name, no interpolation between x-positions happens here:
    the runs are stacked and quantiles are taken across runs (axis 0).

    Args:
        data: Sequence of equally long numeric sequences, one per run.

    Returns:
        Tuple ``(median, lower, upper)`` holding the 0.5, 0.2 and 0.8
        quantiles across runs, each with one entry per position.

    Raises:
        ValueError: If ``data`` is empty or the runs differ in length.
    """
    # One conversion up front replaces the old row-by-row copy into a buffer
    # that was never read; ragged input is rejected here by NumPy.
    samples = np.asarray(data, dtype=float)
    if samples.size == 0:
        raise ValueError("interpolate() needs at least one non-empty run")

    # A single call evaluates all three quantiles; the leading axis of the
    # result indexes the requested quantiles.
    median, error_l, error_u = np.quantile(samples, [0.5, 0.2, 0.8], axis=0)
    return (median, error_l, error_u)


def plot_multi_realdata(n=100, d_list=(100, 500, 1000, 5000, 10000), l=1e-3,
                        num_trials=10,
                        out_path='synthetic_sketch_dimension.pdf'):
    """Run the synthetic sketch-dimension experiment and save the plot.

    For each trial ``i`` the data are regenerated with ``seed=i`` and decay
    base ``a = 0.99 + i*1e-3``, and the sketch sizes are estimated with
    ``compute_m``. Four curves are then drawn against ``d``: ``dl``,
    ``1.5dl``, ``4dl`` and ``m``, each as the median across trials with a
    shaded 20%-80% quantile band (as returned by ``interpolate``).

    Args:
        n: Row count / normaliser forwarded to ``create_data`` and
            ``compute_m``.
        d_list: Dimensions to evaluate; also the x-axis of the plot.
        l: Regularisation level.
        num_trials: Number of independent trials.
        out_path: File the figure is written to.

    The defaults reproduce the original hard-coded experiment.
    """
    d_list = list(d_list)
    d_values = np.array(d_list)

    dl_runs, dl_a_runs, dl_b_runs, m_runs = [], [], [], []
    for i in range(num_trials):
        print(i)
        A_list, trial_dl = create_data(n, d_list, l, seed=i, a=0.99 + i * 1e-3)
        dl, dl_a, dl_b, m = compute_m(A_list, d_list, trial_dl, n, l)
        dl_runs.append(np.array(dl))
        dl_a_runs.append(dl_a)
        dl_b_runs.append(dl_b)
        m_runs.append(m)

    # (summary statistics, legend label, index into the colour palette)
    curves = [
        (interpolate(dl_runs), 'dl', 7),
        (interpolate(dl_a_runs), '1.5dl', 4),
        (interpolate(dl_b_runs), '4dl', 1),
        (interpolate(m_runs), 'm', 2),
    ]

    clrs = sns.color_palette("husl", 10)
    # Only one figure is created: the former extra
    # plt.figure(figsize=(100, 100)) was never drawn on or closed.
    fig, ax = plt.subplots()
    try:
        for (median, lower, upper), label, color_idx in curves:
            color = clrs[color_idx]
            ax.plot(d_values, median, label=label, c=color)
            ax.fill_between(d_values, lower, upper, alpha=0.3, facecolor=color)
        ax.legend(fontsize=22, loc="upper right")
        ax.set_xlabel('d', fontsize=20)
        ax.set_ylabel('sketch dimension', fontsize=20)
        ax.tick_params(axis='both', labelsize=17)
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)




# Script entry point: run the experiment and write the plot only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    plot_multi_realdata()

